from mmengine.registry import MODELS
from mmengine.model import BaseModule
import numpy as np
import torch.nn as nn, torch
import torch.nn.functional as F
from einops import rearrange
from copy import deepcopy
import torch.distributions as dist
from utils.metric_stp3 import PlanningMetric
import time
import torch.nn.init as init
import matplotlib.pyplot as plt
import torchvision.models as models

import sys, os, pdb

class ForkedPdb(pdb.Pdb):
    """A Pdb subclass that may be used from a forked multiprocessing child.

    For the duration of every interactive session ``sys.stdin`` is pointed at
    a freshly opened ``/dev/stdin`` and restored afterwards.
    """

    def interaction(self, *args, **kwargs):
        """Run the regular Pdb interaction with ``sys.stdin`` temporarily replaced.

        All arguments are forwarded unchanged to ``pdb.Pdb.interaction``.
        """
        saved_stdin = sys.stdin
        try:
            # The context manager closes the handle when the session ends, so
            # repeated breakpoints do not leak one file descriptor each.
            with open('/dev/stdin') as tty:
                sys.stdin = tty
                pdb.Pdb.interaction(self, *args, **kwargs)
        finally:
            # Always restore the original stream, even if the debugger raises.
            sys.stdin = saved_stdin

class DoubleConv(nn.Module):
    """Two (3x3 conv => InstanceNorm => ReLU) stages with a residual branch.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        mid_channels: channels between the two convolutions; falls back to
            ``out_channels`` when falsy.

    The residual branch is an identity when ``in_channels == out_channels``
    and a bias-free 1x1 convolution followed by InstanceNorm otherwise.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels

        def conv3x3(c_in, c_out):
            # padding=1 keeps the spatial size unchanged
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False)

        stages = [
            conv3x3(in_channels, hidden),
            nn.InstanceNorm2d(hidden),
            nn.ReLU(inplace=True),
            conv3x3(hidden, out_channels),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*stages)

        if in_channels == out_channels:
            self.residual = nn.Identity()
        else:
            # Project the input so it can be added to the main branch.
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                nn.InstanceNorm2d(out_channels),
            )

    def forward(self, x):
        """Return ``double_conv(x) + residual(x)``."""
        main_branch = self.double_conv(x)
        skip_branch = self.residual(x)
        return main_branch + skip_branch


class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a ``DoubleConv``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv_block = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv_block)

    def forward(self, x):
        """Pool ``x`` and run the convolution block on the result."""
        downsampled = self.maxpool_conv(x)
        return downsampled


class Up(nn.Module):
    """Decoder stage: upsample, align with the skip feature, concat, ``DoubleConv``.

    Args:
        in_channels: channel count handed to the ``DoubleConv``.
        out_channels: channels produced by the ``DoubleConv``.
        bilinear: when true, upsample with a fixed bilinear ``nn.Upsample``
            and use ``in_channels // 2`` as the DoubleConv mid width;
            otherwise upsample with a learned stride-2 transposed convolution
            that outputs ``in_channels // 2`` channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half_channels = in_channels // 2

        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half_channels, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the size of ``x2`` and fuse the two.

        Args:
            x1: feature map to upsample (dims 2 and 3 are treated as H and W).
            x2: skip feature whose spatial size is the target.
        """
        upsampled = self.up(x1)

        # Size gap between the skip feature and the upsampled map.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)

        # Split the gap between both sides; any odd remainder goes to the
        # right / bottom edge.
        left, top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        fused = torch.cat([x2, upsampled], dim=1)
        return self.conv(fused)


class OutConv(nn.Module):
    """Output head: a single 1x1 convolution mapping to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply the 1x1 convolution to ``x``."""
        projected = self.conv(x)
        return projected

def get_position_encoding(t, h, w, d_model):
    """Build a sinusoidal encoding for an ``h x w`` token grid at time step ``t``.

    Even channels hold ``sin`` of the row index and odd channels hold ``cos``
    of the column index, each divided by a channel-dependent power of 10000.
    A per-channel sinusoid of ``t`` is then added to every position, so each
    entry lies in ``[-2, 2]``.

    Args:
        t: time step (number) mixed into every position.
        h: grid height (number of rows).
        w: grid width (number of columns).
        d_model: number of encoding channels; may be odd, in which case the
            last channel is a ``sin`` channel with no ``cos`` partner.

    Returns:
        Float tensor of shape ``[h * w, 1, d_model]``, positions in row-major
        order.
    """
    # Row / column index of every cell in row-major order. This matches a
    # flattened 'ij' meshgrid without calling torch.meshgrid, which warns
    # when no `indexing` argument is given.
    rows = torch.arange(h).repeat_interleave(w).float()
    cols = torch.arange(w).repeat(h).float()

    position_encoding = torch.zeros(h * w, d_model)
    for i in range(0, d_model, 2):
        position_encoding[:, i] = torch.sin(rows / (10000 ** (2 * (i // 2) / d_model)))
        # Guard the paired channel so an odd d_model does not index past the end.
        if i + 1 < d_model:
            position_encoding[:, i + 1] = torch.cos(cols / (10000 ** (2 * (i // 2 + 1) / d_model)))

    # Per-channel sinusoid of the time step, shared by all positions.
    channel = torch.arange(0, d_model)
    time_encoding = torch.sin(t * channel.float() / 10000 ** (2 * (channel / d_model).float()))
    time_encoding = time_encoding.unsqueeze(0).expand(h * w, -1)  # [h*w, d_model]

    pos_encoding = position_encoding + time_encoding  # min: -2, max: 2
    return pos_encoding.unsqueeze(1)  # [h*w, 1, d_model]

def compose_triplane_channelwise(feat_maps):
    """Zero-pad three plane feature maps to a common size and stack them on dim 1.

    Args:
        feat_maps: sequence ``(h_xy, h_xz, h_yz)`` whose trailing spatial
            sizes are ``(H, W)``, ``(H, D)`` and ``(W, D)``; all three must
            agree in ``shape[1]``.

    Returns:
        Tuple ``(h, (H, W, D))`` where ``h`` is the dim-1 concatenation of
        the padded planes, each padded on the bottom / right to
        ``(max(H, W), max(W, D))``.
    """
    plane_xy, plane_xz, plane_yz = feat_maps
    assert plane_xy.shape[1] == plane_xz.shape[1] == plane_yz.shape[1]
    _, H, W = plane_xy.shape[-3:]
    D = plane_xz.shape[-1]

    canvas_h, canvas_w = max(H, W), max(W, D)

    # (tensor, nominal height, nominal width) for each plane
    layout = (
        (plane_xy, H, W),
        (plane_xz, H, D),
        (plane_yz, W, D),
    )
    padded = [
        F.pad(plane, (0, canvas_w - width, 0, canvas_h - height))
        for plane, height, width in layout
    ]
    composed = torch.cat(padded, dim=1)
    return composed, (H, W, D)

def decompose_triplane_channelwise(composed_map, sizes):
    """Split a dim-1 stacked triplane map back into its three cropped planes.

    Args:
        composed_map: 4-D tensor whose dim 1 holds the three planes back to
            back (the third plane takes every channel from ``2 * C`` onward,
            with ``C = shape[1] // 3``).
        sizes: ``(H, W, D)`` crop sizes.

    Returns:
        Tuple ``(h_xy, h_xz, h_yz)`` cropped to ``(H, W)``, ``(H, D)`` and
        ``(W, D)`` respectively.
    """
    H, W, D = sizes
    per_plane = composed_map.shape[1] // 3

    xy_block = composed_map[:, :per_plane]
    xz_block = composed_map[:, per_plane:2 * per_plane]
    yz_block = composed_map[:, 2 * per_plane:]

    return (
        xy_block[:, :, :H, :W],
        xz_block[:, :, :H, :D],
        yz_block[:, :, :W, :D],
    )


class TemporalTransformer(nn.Module):
    """Transformer encoder/decoder that predicts the next frame's feature map.

    The history feature maps are flattened into one token sequence, tagged
    with learned position embeddings and encoded; the embeddings of the next
    time slot act as decoder queries, and a linear head maps the decoded
    tokens to ``out_c`` channels.

    Args:
        d_model: token width of the encoder / decoder.
        nhead: number of attention heads.
        num_layers: number of encoder layers and of decoder layers.
        out_c: channel count of the predicted feature map.
        out_s: ``(h, w)`` spatial size of the feature maps handled here.
        level: index into ``[256, 128, 64, 32, 16] * scale`` that selects the
            width of the learned position embedding.
        type: label stored on the module (not used for computation here).
        scale: width multiplier for the channel table above.
        dim_feedforward: hidden width of the transformer feed-forward blocks.
        prev_steps: number of most recent frames kept as encoder input.

    NOTE: ``learn_pos`` is a registered ``nn.Parameter``. It previously had
    ``.cuda()`` called on it, which produced a plain non-leaf tensor that was
    never optimised, never saved, and required CUDA at construction time.
    Checkpoints written before this change carry no ``learn_pos`` entry and
    need ``strict=False`` when loaded.
    """

    def __init__(self, d_model, nhead, num_layers, out_c, out_s, level, type, scale, dim_feedforward=512, prev_steps=4):
        super(TemporalTransformer, self).__init__()

        self.prev_steps = prev_steps
        self.type = type
        self.level = level

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                dim_feedforward=dim_feedforward),
            num_layers=num_layers
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead,
                dim_feedforward=dim_feedforward),
            num_layers=num_layers
        )

        self.ccc = [256*scale, 128*scale, 64*scale, 32*scale, 16*scale]
        # One embedding per token for 10 time slots of out_s[0]*out_s[1]
        # tokens each. Kept as a Parameter so it follows the module's device
        # and is visible to the optimizer and to state_dict().
        self.learn_pos = nn.Parameter(
            torch.empty(out_s[0] * out_s[1] * 10, 1, self.ccc[level]))
        init.normal_(self.learn_pos, mean=0.0, std=0.1)

        self.d_model = d_model
        self.out_c = out_c
        self.out_s = out_s
        self.fc_out = nn.Linear(d_model, out_c)  # per-token output projection

    def forward(self, x, ttt):
        """Predict the feature map of the frame following the given history.

        Args:
            x: history features of shape ``(his, c, h, w)`` with
                ``(h, w) == out_s``.
            ttt: index of the time slot to predict; must equal ``his``.

        Returns:
            Tensor of shape ``(1, out_c, out_s[0], out_s[1])``.
        """
        his, c, h, w = x.shape
        assert h==self.out_s[0] and w==self.out_s[1]

        # Flatten (his, c, h, w) into a token sequence [his*h*w, 1, c].
        x = x.permute(0, 2, 3, 1).contiguous()  # his, h, w, c
        x = x.view(1, his*h*w, c)
        x = x.permute(1, 0, 2)  # [seq_len, batch_size, feature_dim]

        # 1. encode memory: add position embeddings, then keep only the
        #    tokens of the latest `prev_steps` frames.
        x = x + self.learn_pos[:h*w*his]
        x = x[-self.prev_steps*h*w:, ...]
        memory = self.encoder(x)

        # 2. decode future: the embeddings of time slot `his` are the queries.
        assert ttt == his
        query = self.learn_pos[h*w*his: h*w*(his+1)]
        output = self.decoder(query, memory)  # num_token, 1, token_dim
        output = output.permute(1, 0, 2)  # 1, num_token, token_dim

        # 3. project tokens to out_c channels and restore the spatial layout.
        output = self.fc_out(output)
        output = output.permute(0, 2, 1).contiguous()
        output = output.view(-1, self.out_c, self.out_s[0], self.out_s[1])

        return output

class SELayer(nn.Module):
    """Squeeze-and-excitation style channel gate.

    Args:
        channel: number of channels of the input feature map.
        reduction: divisor used for the hidden width of the two linear layers.
    """

    def __init__(self, channel, reduction=4):
        super().__init__()
        bottleneck = channel // reduction
        # Squeeze: global average pooling collapses the spatial dimensions.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: two linear layers forming a bottleneck.
        self.fc1 = nn.Linear(channel, bottleneck)
        self.fc2 = nn.Linear(bottleneck, channel)
        # Sigmoid turns the excitation output into per-channel weights.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Rescale every channel of the 4-D input ``x`` by its learned gate."""
        batch, channels, _height, _width = x.size()

        # 1. squeeze: one descriptor value per (sample, channel)
        descriptor = self.avg_pool(x).view(batch, channels)

        # 2. excitation: bottleneck MLP followed by a sigmoid
        hidden = F.relu(self.fc1(descriptor))
        gate = self.sigmoid(self.fc2(hidden))
        gate = gate.view(batch, channels, 1, 1)

        # 3. recalibrate the input feature map
        return x * gate.expand_as(x)

def replace_bn(module, norm_layer, **kwargs):
    """Recursively swap every ``nn.BatchNorm2d`` under ``module`` in place.

    Args:
        module: root module whose descendants are inspected and modified.
        norm_layer: callable building the replacement; invoked with
            ``num_features`` and ``affine`` copied from the BatchNorm layer.
        **kwargs: extra keyword arguments forwarded to ``norm_layer``.
    """
    for child_name, submodule in module.named_children():
        if not isinstance(submodule, nn.BatchNorm2d):
            # Not a BatchNorm layer: descend into its own children.
            replace_bn(submodule, norm_layer, **kwargs)
            continue

        # Build the new normalisation layer with matching width / affine flag
        # and rebind it under the same attribute name.
        replacement = norm_layer(
            num_features=submodule.num_features,
            affine=submodule.affine,
            **kwargs
        )
        setattr(module, child_name, replacement)

class TrajNet(nn.Module):
    """Autoregressive trajectory head driven by triplane features.

    Two ResNet-18 backbones (BatchNorm replaced by InstanceNorm, 24-channel
    stem) embed (a) pairs of consecutive composed triplane frames and (b) the
    per-step delta planes. The embeddings are fused into one 50-d feature per
    future step. A small transformer encoder/decoder then predicts one step
    at a time from the last four poses, with every prediction appended to the
    pose history.

    Args:
        n_channels: accepted for interface compatibility; not used here (the
            ResNet stems are hard-wired to 24 input channels).
        reduction: accepted for interface compatibility; not used here.

    NOTE: ``learn_pos`` is a registered ``nn.Parameter``. It previously had
    ``.cuda()`` called on it, which produced a plain non-leaf tensor that was
    never optimised or saved. Checkpoints written before this change carry no
    ``learn_pos`` entry and need ``strict=False`` when loaded.
    """

    def __init__(self, n_channels, reduction):
        super(TrajNet, self).__init__()

        self.resnet_base = models.resnet18()
        self.resnet_delt = models.resnet18()

        # Composed triplanes have 24 channels, so the stems are replaced.
        self.resnet_base.conv1 = nn.Conv2d(24, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
        self.resnet_delt.conv1 = nn.Conv2d(24, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
        replace_bn(self.resnet_base, nn.InstanceNorm2d)
        replace_bn(self.resnet_delt, nn.InstanceNorm2d)
        self.resnet_base.fc = nn.Sequential(
            nn.Linear(self.resnet_base.fc.in_features, 100),
            nn.Sigmoid()
        )
        self.resnet_delt.fc = nn.Sequential(
            nn.Linear(self.resnet_delt.fc.in_features, 100),
            nn.ReLU()
        )
        # Gates derived from the delta embedding; they multiply the embedding
        # of the earlier / later frame of each pair in forward().
        self.delt_prev = nn.Sequential(
            nn.Linear(100, 100),
            nn.Sigmoid()
        )
        self.delt_next = nn.Sequential(
            nn.Linear(100, 100),
            nn.Sigmoid()
        )

        self.last_fc = nn.Sequential(nn.Linear(6, 50), nn.ReLU(inplace=True))
        self.now_fc = nn.Sequential(nn.Linear(100, 50), nn.ReLU(inplace=True))

        self.tf_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=50, nhead=2, dim_feedforward=128),
            num_layers=2
        )
        self.tf_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=50, nhead=2, dim_feedforward=128),
            num_layers=2
        )

        # Learned embedding for 10 time slots; a Parameter so it follows the
        # module's device and is visible to the optimizer and state_dict().
        self.learn_pos = nn.Parameter(torch.empty(10, 1, 50))
        init.normal_(self.learn_pos, mean=0.0, std=0.1)

        self.flatten = nn.Flatten(start_dim=1)
        self.traj_out = nn.Sequential(
            nn.Linear(50, 25),
            nn.ReLU(inplace=True),
            nn.Linear(25, 6),
        )

        # Consecutive frame pairs (i, i + 1) taken from the 7-frame stack.
        self.idx = [
            [0, 1],
            [1, 2],
            [2, 3],
            [3, 4],
            [4, 5],
            [5, 6]
        ]

    def forward(self,
        xys_history, xzs_history, yzs_history,
        xy_delta, xz_delta, yz_delta,
        xys_future, xzs_future, yzs_future,
        metas
    ):
        """Predict one ``(3, 2)`` trajectory block per future step.

        Args:
            xys_history, xzs_history, yzs_history: one history frame per
                plane; a leading frame dim is added before it is concatenated
                with the future frames.
            xy_delta, xz_delta, yz_delta: per-step delta planes.
            xys_future, xzs_future, yzs_future: future frames per plane
                (presumably 6, so that history + future gives the 7 frames
                indexed by ``self.idx`` -- confirm against the caller).
            metas: iterable of dicts; ``meta['rel_poses']`` supplies the
                relative poses whose first four steps seed the pose history.

        Returns:
            Tensor of predictions concatenated over the future steps; each
            step contributes a ``(1, 3, 2)`` block.
        """
        # base: stack history + future frames and compose the three planes.
        xys_base = torch.cat([xys_history.unsqueeze(0), xys_future])
        xzs_base = torch.cat([xzs_history.unsqueeze(0), xzs_future])
        yzs_base = torch.cat([yzs_history.unsqueeze(0), yzs_future])
        h_base, _ = compose_triplane_channelwise([xys_base, xzs_base, yzs_base])  # e.g. [7, 24, 100, 100]
        # Group consecutive frames into pairs, then fold the pair dim into
        # the batch dim. detach() stops gradients flowing into the producers
        # of these frames.
        h_base = torch.stack([h_base[[i, j]] for i, j in self.idx], dim=0)
        b, t, c, h, w = h_base.shape
        h_base = h_base.view(b*t, c, h, w).detach()  # e.g. (12, 24, 100, 100)
        h_base = self.resnet_base(h_base)  # (b*t, 100)

        # delta
        h_delta, _ = compose_triplane_channelwise([xy_delta, xz_delta, yz_delta])  # e.g. [6, 24, 100, 100]
        h_delta = self.resnet_delt(h_delta)  # (6, 100)
        h_delta_prev = self.delt_prev(h_delta)  # (6, 100)
        h_delta_next = self.delt_next(h_delta)  # (6, 100)

        # last time pose: (B, F, 2) repeated over 3 modes -> (B, F, 3, 2)
        gt_rel_pose = [meta['rel_poses'] for meta in metas]
        gt_rel_pose = torch.tensor(np.asarray(gt_rel_pose), dtype=torch.float32)
        gt_rel_pose = gt_rel_pose.unsqueeze(1).repeat(1, 3, 1, 1)  # B, M, F, 2
        gt_rel_pose = gt_rel_pose.transpose(1, 2)  # B, F, M, 2

        # pred traj: average the gated "previous" embedding, the gated
        # "next" embedding and the raw delta embedding of every step.
        all_h = torch.cat((
            (h_delta_prev*h_base[::2]).unsqueeze(1),
            (h_delta_next*h_base[1::2]).unsqueeze(1),
            (h_delta).unsqueeze(1)
        ), dim=1)  # (6, 3, 100)
        all_h = torch.mean(all_h, dim=1)  # (6, 100)
        all_h = self.flatten(all_h)  # (6, 100)
        all_h = self.now_fc(all_h)  # (6, 50)

        # Seed the pose history with the first four poses, placed on the same
        # device as the features instead of a hard-coded CUDA device.
        last_trajs = gt_rel_pose[:, :4, :, :].to(all_h.device)  # (1, 4, 3, 2)
        pred_trajs = []
        for ttt in range(all_h.shape[0]):

            # Embed the four most recent poses and tag them with their slots.
            last_h = last_trajs[:, -4:, :, :].flatten(start_dim=2)  # (1, 4, 6)
            last_h = self.last_fc(last_h)  # (1, 4, 50)
            last_h = last_h.permute(1, 0, 2)  # (4, 1, 50)
            last_h = last_h + self.learn_pos[ttt: ttt+4]

            memory = self.tf_encoder(last_h)  # (4, 1, 50)

            # Query = embedding of the slot being predicted + scene feature.
            query = self.learn_pos[ttt+4].unsqueeze(0)  # (t=1, b=1, 50)
            query = query + all_h[ttt].unsqueeze(0).unsqueeze(0)
            output = self.tf_decoder(query, memory)  # (1, 1, 50)
            output = output.squeeze(1)

            new_traj = self.traj_out(output)  # (1, 6)
            new_traj = new_traj.reshape(1, 3, 2)

            pred_trajs.append(new_traj)
            # Feed the prediction back as the newest entry of the history.
            last_trajs = torch.cat((last_trajs, new_traj.unsqueeze(0)), dim=1)

        pred_trajs = torch.cat(pred_trajs)

        return pred_trajs


class UNetTF(nn.Module):
    """U-Net whose skip connections are rolled forward in time by transformers.

    The first ``prev_steps`` frames are encoded at five scales. For every
    remaining frame, a ``TemporalTransformer`` per scale predicts a motion
    feature from the running history. The U-Net decoder turns the five motion
    features into a delta, which is added to a 1x1-conv projection of the
    previous frame to form the next prediction.

    Args:
        n_channels: channels of the input frames.
        n_classes: channels of the predicted frames / deltas.
        prev_steps: number of leading frames used as history.
        out_s: five ``[h, w]`` sizes, finest first, giving the spatial size
            each transformer works at. The two finest encoder outputs are
            average-pooled to ``out_s[0]`` / ``out_s[1]``; the coarser ones
            are used as produced.
        type: plane label forwarded to the transformers.
        scale: width multiplier for every channel count.
        bilinear: forwarded to the ``Up`` blocks.
    """

    def __init__(self, n_channels, n_classes, prev_steps, out_s, type, scale=2, bilinear=False):
        super(UNetTF, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.out_s = out_s
        self.type = type
        self.scale = scale
        self.prev_steps = prev_steps  # number of history frames used to predict the future

        self.inc = (DoubleConv(n_channels, 16*scale))
        self.down1 = (Down(16*scale, 32*scale))
        self.down2 = (Down(32*scale, 64*scale))
        self.down3 = (Down(64*scale, 128*scale))
        self.down4 = (Down(128*scale, 256*scale))
        self.up1 = (Up(256*scale, 128*scale, bilinear))
        self.up2 = (Up(128*scale, 64*scale, bilinear))
        self.up3 = (Up(64*scale, 32*scale, bilinear))
        self.up4 = (Up(32*scale, 16*scale, bilinear))

        # Transformer blocks for autoregressive prediction, one per scale.
        self.transformer1 = TemporalTransformer(
            16*scale, out_c=16*scale, out_s=out_s[-5], level=4, nhead=1, num_layers=4, type=type, scale=scale, prev_steps=prev_steps)
        self.transformer2 = TemporalTransformer(
            32*scale, out_c=32*scale, out_s=out_s[-4], level=3, nhead=2, num_layers=4, type=type, scale=scale, prev_steps=prev_steps)
        self.transformer3 = TemporalTransformer(
            64*scale, out_c=64*scale, out_s=out_s[-3], level=2, nhead=4, num_layers=4, type=type, scale=scale, prev_steps=prev_steps)
        self.transformer4 = TemporalTransformer(
            128*scale, out_c=128*scale, out_s=out_s[-2], level=1, nhead=8, num_layers=4, type=type, scale=scale, prev_steps=prev_steps)
        self.transformer5 = TemporalTransformer(
            256*scale, out_c=256*scale, out_s=out_s[-1], level=0, nhead=16, num_layers=4, type=type, scale=scale, prev_steps=prev_steps)

        self.norm = nn.InstanceNorm2d(n_classes)
        self.dropout = nn.Dropout2d(0.1)

        self.outdelta = (OutConv(16*scale, n_classes))
        # No hard-coded .cuda(): the layer follows whatever device the parent
        # module is moved to, and construction works without a GPU.
        self.shortcut = nn.Conv2d(n_channels, n_classes, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Roll the model forward over every frame after the history window.

        Args:
            x: frames of shape ``(T, n_channels, H, W)``; the first
                ``prev_steps`` frames are the history. The later frames are
                only read as teacher-forcing inputs during training.

        Returns:
            Tuple ``(pred_occs, pred_deltas)``, both concatenated over the
            ``T - prev_steps`` predicted steps along dim 0.
        """
        x = self.dropout(x)

        history_x = x.clone()

        # Multi-scale encoding of the history frames, e.g. for a 100x100
        # plane with scale=1:
        x1 = self.inc(x[:self.prev_steps])  # [4, 16, 100, 100]
        x2 = self.down1(x1)  # [4, 32, 50, 50]
        x3 = self.down2(x2)  # [4, 64, 25, 25]
        x4 = self.down3(x3)  # [4, 128, 12, 12]
        x5 = self.down4(x4)  # [4, 256, 6, 6]

        # Remember the native resolution of the two finest scales so the
        # predicted motion can be upsampled back to it. This replaces the
        # previously hard-coded (100, 100|16) and (50, 50|8) sizes with the
        # same values derived from the input.
        full_res = tuple(x1.shape[-2:])
        half_res = tuple(x2.shape[-2:])

        # Shrink the two finest scales to keep the token count manageable.
        x1 = F.adaptive_avg_pool2d(x1, self.out_s[0])
        x2 = F.adaptive_avg_pool2d(x2, self.out_s[1])

        history_x1 = x1.clone()
        history_x2 = x2.clone()
        history_x3 = x3.clone()
        history_x4 = x4.clone()
        history_x5 = x5.clone()

        pred_occs = []
        pred_deltas = []

        teacher_forcing_times_occ = 0
        future_pred_times = x[self.prev_steps:].shape[0]
        for ttt in range(future_pred_times):
            # Predict the next motion feature at every scale.
            next_x1 = self.transformer1(history_x1, self.prev_steps+ttt)
            next_x2 = self.transformer2(history_x2, self.prev_steps+ttt)
            next_x3 = self.transformer3(history_x3, self.prev_steps+ttt)
            next_x4 = self.transformer4(history_x4, self.prev_steps+ttt)
            next_x5 = self.transformer5(history_x5, self.prev_steps+ttt)  # e.g. [1, 256, 6, 6]

            # Aggregate multi-scale motion: latest history feature + predicted
            # motion gives the new feature at each scale.
            new_x1 = history_x1[-1].unsqueeze(0) + next_x1
            new_x2 = history_x2[-1].unsqueeze(0) + next_x2
            new_x3 = history_x3[-1].unsqueeze(0) + next_x3
            new_x4 = history_x4[-1].unsqueeze(0) + next_x4
            new_x5 = history_x5[-1].unsqueeze(0) + next_x5

            # Extend the running histories with the new features.
            history_x1 = torch.cat([history_x1, new_x1], dim=0)
            history_x2 = torch.cat([history_x2, new_x2], dim=0)
            history_x3 = torch.cat([history_x3, new_x3], dim=0)
            history_x4 = torch.cat([history_x4, new_x4], dim=0)
            history_x5 = torch.cat([history_x5, new_x5], dim=0)  # e.g. [4+?, 256, 6, 6]

            # Undo the adaptive pooling on the two finest scales.
            next_x1 = F.interpolate(next_x1, size=full_res, mode='bilinear', align_corners=False)
            next_x2 = F.interpolate(next_x2, size=half_res, mode='bilinear', align_corners=False)

            pred_y = self.up1(next_x5, next_x4)
            pred_y = self.up2(pred_y, next_x3)
            pred_y = self.up3(pred_y, next_x2)
            pred_y = self.up4(pred_y, next_x1)  # e.g. [1, 16*scale, 100, 100]

            # Predicted change of the occupancy features for this step.
            pred_delta = self.outdelta(pred_y)
            pred_deltas.append(pred_delta.clone())

            # The previous frame is always detached, so gradients do not
            # flow back through earlier steps via the shortcut.
            if len(pred_occs) == 0:
                assert ttt == 0
                # First step: start from the last history frame. The explicit
                # batch dim keeps the conv input 4-D like every other step.
                last_occ = history_x[self.prev_steps-1].unsqueeze(0).clone().detach()
            elif self.training and np.random.rand() < 0.1 and \
                    teacher_forcing_times_occ < future_pred_times:
                # Teacher forcing with probability 0.1 during training.
                # NOTE(review): this reads frame prev_steps + ttt, which looks
                # like the frame being predicted rather than the one before
                # it (prev_steps + ttt - 1) -- confirm this is intended.
                last_occ = history_x[self.prev_steps:][ttt].unsqueeze(0).clone().detach()
                teacher_forcing_times_occ += 1
            else:
                last_occ = pred_occs[-1].clone().detach()

            pred_occ = self.shortcut(last_occ) + pred_delta
            pred_occ = (self.norm(pred_occ) * 0.5).tanh()

            pred_occs.append(pred_occ)

        pred_occs = torch.cat(pred_occs, dim=0)  # e.g. (6, n_classes, 100, 100)
        pred_deltas = torch.cat(pred_deltas, dim=0)

        return pred_occs, pred_deltas


class SiLU(nn.Module):
    """Sigmoid-weighted linear unit: ``x * sigmoid(x)`` applied element-wise."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

@MODELS.register_module()
class TransHexplane(BaseModule):
    def __init__(self,
                 num_frames=10, offset=1, prev_steps=4, num_classes=18,
                 triplane_cfg=None, encoder_cfg=None, decoder_cfg=None,
                 pose_encoder=None, pose_decoder=None,
                 pose_actor=None, give_hiddens=False, delta_input=False, without_all=False):
        """Assemble the per-plane forecasting networks around a triplane model.

        Args:
            num_frames: stored as ``self.num_frames``; ``forward_train`` asserts
                that its input carries ``num_frames + offset`` frames.
            offset: stored as ``self.offset`` (see ``num_frames``).
            prev_steps: index separating history frames from future frames;
                also forwarded to every ``UNetTF``.
            num_classes: size of the class dimension, stored as ``self.num_cls``.
            triplane_cfg: registry config built with ``MODELS.build``; the
                built model's ``encoder`` and ``class_embeds`` are deleted.
            encoder_cfg: accepted for interface compatibility, not used here.
            decoder_cfg: accepted for interface compatibility, not used here.
            pose_encoder: optional registry config; when not ``None`` the built
                module becomes ``self.pose_encoder``.
            pose_decoder: optional registry config, handled like ``pose_encoder``.
            pose_actor: optional registry config, handled like ``pose_encoder``.
            give_hiddens: stored unchanged.
            delta_input: stored unchanged.
            without_all: stored unchanged.
        """
        super().__init__()
        self.num_frames = num_frames
        self.prev_steps = prev_steps
        self.offset = offset
        self.num_cls = num_classes

        # Optional pose modules: the attribute is only created when a config is
        # given, because ``forward`` tests for ``pose_encoder`` with ``hasattr``.
        optional_pose_cfgs = (
            ('pose_encoder', pose_encoder),
            ('pose_decoder', pose_decoder),
            ('pose_actor', pose_actor),
        )
        for attr_name, cfg in optional_pose_cfgs:
            if cfg is not None:
                setattr(self, attr_name, MODELS.build(cfg))

        triplane = MODELS.build(triplane_cfg)
        self.triplane_net = triplane
        # These two sub-objects are dropped from the triplane model; this class
        # only calls its ``forward_decoder``.
        del triplane.encoder
        del triplane.class_embeds

        self.give_hiddens = give_hiddens
        self.delta_input = delta_input
        # Filled in on first use by ``compute_planner_metric_stp3``.
        self.planning_metric = None
        self.without_all = without_all

        # One temporal U-Net per plane, created in xy, xz, yz order. Every plane
        # gets its own freshly built ``out_s`` list.
        plane_out_sizes = (
            ('xy', [[25, 25], [25, 25], [25, 25], [12, 12], [6, 6]]),
            ('xz', [[25, 16], [25, 8], [25, 4], [12, 2], [6, 1]]),
            ('yz', [[25, 16], [25, 8], [25, 4], [12, 2], [6, 1]]),
        )
        for plane, out_sizes in plane_out_sizes:
            setattr(
                self,
                f'transnet_{plane}',
                UNetTF(8, 8, prev_steps=prev_steps, scale=1, type=plane,
                       out_s=out_sizes),
            )

        self.trajnet = TrajNet(8 * 3, 4)

        self.norm = nn.InstanceNorm2d(8)

    def forward(self, input_occs, querys, xys, xzs, yzs,
                xyz_labels=None, xyz_centers=None, metas=None):
        if hasattr(self, 'pose_encoder'):
            if self.training:
                return self.forward_train_with_plan(x, metas)
            else:
                return self.forward_inference_with_plan(x, metas)
        if self.training:
            return self.forward_train(input_occs, querys, xys, xzs, yzs, xyz_labels, xyz_centers, metas)
        else:
            return self.forward_inference(input_occs, querys, xys, xzs, yzs, xyz_labels, xyz_centers, metas)
        

    def forward_train(self, 
            input_occs, querys, 
            xys, xzs, yzs, 
            xyz_labels, xyz_centers, metas):

        bs, F, H, W, D = input_occs.shape
        # [1, 16, 200, 200, 16]
        assert F == self.num_frames + self.offset
        output_dict = {}

        # [111]
        # h, _ = compose_triplane_channelwise([xys, xzs, yzs]) # [10, 24, 100, 100]
        # h, pred_trajs = self.transnet_xyz(h)    # [6, 24, 100, 100]
        # xys_future, xzs_future, yzs_future = decompose_triplane_channelwise(h, (H, W, D))

        # [222]
        xys_future, xy_delta = self.transnet_xy(xys)    # [12, 8, 100, 100]
        xzs_future, xz_delta = self.transnet_xz(xzs)    # [12, 8, 100, 8]
        yzs_future, yz_delta = self.transnet_yz(yzs)    # [12, 8, 100, 8]
        # pred_trajs = (xy_pred_trajs+xz_pred_trajs+yz_pred_trajs)/3
        pred_trajs = self.trajnet(
            xys[self.prev_steps], xzs[self.prev_steps], yzs[self.prev_steps],
            xy_delta, xz_delta, yz_delta,
            xys_future, xzs_future, yzs_future,
            metas
        )

        # with torch.no_grad():
        preds = self.triplane_net.forward_decoder(
            [xys_future, xzs_future, yzs_future], querys[self.prev_steps:])
        # [12, 200000, 18]
        softmax_preds = torch.nn.functional.softmax(preds, dim=2)

        empty_label = 0.
        pred_logits = torch.full(
            (preds.shape[0], 200, 200, 16, self.num_cls), 
            fill_value=empty_label, device=preds.device)    # [6, 200, 200, 16, 18]
        pred_output = torch.full(
            (preds.shape[0], 200, 200, 16, self.num_cls), 
            fill_value=empty_label, device=preds.device)
        pred_logits[:, :, :, :, -1] = 1 # 保证最后17类别位置是1，其他为0，这样默认17是非占据的
        pred_output[:, :, :, :, -1] = 1
        for i in range(softmax_preds.shape[0]):
            pred_output[i, 
                xyz_centers[i+self.prev_steps, :, 0], 
                xyz_centers[i+self.prev_steps, :, 1], 
                xyz_centers[i+self.prev_steps, :, 2], :] = softmax_preds[i]
            pred_logits[i, 
                xyz_centers[i+self.prev_steps, :, 0], 
                xyz_centers[i+self.prev_steps, :, 1], 
                xyz_centers[i+self.prev_steps, :, 2], :] = preds[i]

        # only for debug
        # pred_logits = pred_logits.unsqueeze(0) 
        # pred = pred_logits.argmax(dim=-1).detach().cuda()
        # output_dict['sem_pred'] = pred  # [10, 200, 200, 16]
        # pred_iou = deepcopy(pred)
        # pred_iou[pred_iou!=17] = 1
        # pred_iou[pred_iou==17] = 0
        # output_dict['iou_pred'] = pred_iou

        # ForkedPdb().set_trace()
        # 处理traj的GT
        output_metas = [{
            'rel_poses': metas[0]['rel_poses'][self.prev_steps:],#[np.newaxis, ...],
            'gt_mode': metas[0]['gt_mode'][self.prev_steps:],#[np.newaxis, ...],
        }]

        output_dict.update({
            'preds': preds,     # [12, 200000, 18]
            'pred_output': pred_output,  # [12, 200, 200, 16, 18]
            'hexplane': [xys_future, xzs_future, yzs_future],
            'hexplane_mask': None,
            # 'hexplane_mask': [xys_mask, xzs_mask, yzs_mask]

            'pose_decoded': pred_trajs.unsqueeze(0),  # [1, 6, 3, 2]
            'output_metas': output_metas
            # output_dict['output_metas'] = output_metas  # 
        })

        return output_dict
        
    
    def forward_inference(self, 
            input_occs, querys, 
            xys, xzs, yzs, 
            xyz_labels, xyz_centers, metas):

        bs, F, H, W, D = input_occs.shape
        # [1, 16, 200, 200, 16]
        # assert F == self.num_frames + self.offset
        output_dict = {}

        # [111]
        # h, _ = compose_triplane_channelwise([xys, xzs, yzs]) # [10, 24, 100, 100]
        # h, pred_trajs = self.transnet_xyz(h)    # [6, 24, 100, 100]
        # xys_future, xzs_future, yzs_future = decompose_triplane_channelwise(h, (H, W, D))

        # [222]
        t0 = time.time()
        xys_future, xy_delta = self.transnet_xy(xys)    # [12, 8, 100, 100]
        xzs_future, xz_delta = self.transnet_xz(xzs)    # [12, 8, 100, 8]
        yzs_future, yz_delta = self.transnet_yz(yzs)    # [12, 8, 100, 8]
        # pred_trajs = (xy_pred_trajs+xz_pred_trajs+yz_pred_trajs)/3
        pred_trajs = self.trajnet(
            xys[self.prev_steps], xzs[self.prev_steps], yzs[self.prev_steps],
            xy_delta, xz_delta, yz_delta,
            xys_future, xzs_future, yzs_future,
            metas,
        )

        # with torch.no_grad():
        preds = self.triplane_net.forward_decoder(
            [xys_future, xzs_future, yzs_future], querys[self.prev_steps:])
        softmax_preds = torch.nn.functional.softmax(preds, dim=2)
        t1 = time.time()

        empty_label = 0.
        pred_logits = torch.full(
            (preds.shape[0], 200, 200, 16, self.num_cls), 
            fill_value=empty_label, device=preds.device)
        pred_output = torch.full(
            (preds.shape[0], 200, 200, 16, self.num_cls), 
            fill_value=empty_label, device=preds.device)
        pred_logits[:, :, :, :, -1] = 1 # 保证最后17类别位置是1，其他为0，这样默认17是非占据的
        pred_output[:, :, :, :, -1] = 1
        for i in range(softmax_preds.shape[0]):
            pred_output[i, 
                xyz_centers[i+self.prev_steps, :, 0], 
                xyz_centers[i+self.prev_steps, :, 1], 
                xyz_centers[i+self.prev_steps, :, 2], :] = softmax_preds[i]
            pred_logits[i, 
                xyz_centers[i+self.prev_steps, :, 0], 
                xyz_centers[i+self.prev_steps, :, 1], 
                xyz_centers[i+self.prev_steps, :, 2], :] = preds[i]

        # 处理traj的GT
        output_metas = [{
            'rel_poses': metas[0]['rel_poses'][self.prev_steps:],#[np.newaxis, ...],
            'gt_mode': metas[0]['gt_mode'][self.prev_steps:],#[np.newaxis, ...],
        }]

        output_dict.update({
            'preds': preds,     # [12, 200000, 18]
            'pred_output': pred_output,  # [12, 200, 200, 16, 18]
            'hexplane': [xys_future, xzs_future, yzs_future],
            'hexplane_mask': None,
            # 'hexplane_mask': [xys_mask, xzs_mask, yzs_mask]

            'pose_decoded': pred_trajs.unsqueeze(0),  # [1, 6, 3, 2]
            'output_metas': output_metas
            # output_dict['output_metas'] = output_metas  # 
        })


        if not self.training:
            pred = pred_output.unsqueeze(0).argmax(dim=-1).detach().cuda()
            output_dict['sem_pred'] = pred  # [1, 6, 200, 200, 16]
            pred_iou = deepcopy(pred)
            
            pred_iou[pred_iou!=17] = 1
            pred_iou[pred_iou==17] = 0
            output_dict['iou_pred'] = pred_iou
            
            # part2
            pred_ego_fut_trajs = output_dict['pose_decoded']    # (1,6,2)
            gt_mode = torch.tensor([meta['gt_mode'] for meta in output_dict['output_metas']])   # (1,6,3)
            bs, num_frames, num_modes, _ = pred_ego_fut_trajs.shape
            pred_ego_fut_trajs = pred_ego_fut_trajs[gt_mode.bool()].reshape(bs, num_frames, 2)
            pred_ego_fut_trajs = torch.cumsum(pred_ego_fut_trajs, dim=1).cpu()
            gt_ego_fut_trajs = torch.tensor([meta['rel_poses'] for meta in output_dict['output_metas']])
            gt_ego_fut_trajs = torch.cumsum(gt_ego_fut_trajs, dim=1).cpu()
            assert len(metas) == 1, f'len(metas): {len(metas)}'
            gt_bbox = metas[0]['gt_bboxes_3d']
            gt_attr_labels = torch.tensor(metas[0]['attr_labels'])
            fut_valid_flag = torch.tensor(metas[0]['fut_valid_flag'])
            # import pdb;pdb.set_trace()
            metric_stp3 = self.compute_planner_metric_stp3(
                pred_ego_fut_trajs, # (1,6,2)
                gt_ego_fut_trajs,   # (1,6,2)
                gt_bbox, gt_attr_labels[None], True)
            
            output_dict['metric_stp3'] = metric_stp3


        output_dict['time'] = {
            'total':t1-t0, 'per_frame':(t1-t0)/6}
        return output_dict

    def compute_planner_metric_stp3(
        self,
        pred_ego_fut_trajs,
        gt_ego_fut_trajs,
        gt_agent_boxes,
        gt_agent_feats,
        fut_valid_flag
    ):
        """Compute planner metric for one sample same as stp3"""
        metric_dict = {
            'plan_L2_1s':0,
            'plan_L2_2s':0,
            'plan_L2_3s':0,
            'plan_obj_col_1s':0,
            'plan_obj_col_2s':0,
            'plan_obj_col_3s':0,
            'plan_obj_box_col_1s':0,
            'plan_obj_box_col_2s':0,
            'plan_obj_box_col_3s':0,
            'plan_L2_1s_single':0,
            'plan_L2_2s_single':0,
            'plan_L2_3s_single':0,
            'plan_obj_col_1s_single':0,
            'plan_obj_col_2s_single':0,
            'plan_obj_col_3s_single':0,
            'plan_obj_box_col_1s_single':0,
            'plan_obj_box_col_2s_single':0,
            'plan_obj_box_col_3s_single':0,
            
        }
        metric_dict['fut_valid_flag'] = fut_valid_flag
        future_second = 3
        assert pred_ego_fut_trajs.shape[0] == 1, 'only support bs=1'
        if self.planning_metric is None:
            self.planning_metric = PlanningMetric()
        segmentation, pedestrian = self.planning_metric.get_label(
            gt_agent_boxes, gt_agent_feats)
        occupancy = torch.logical_or(segmentation, pedestrian)
        for i in range(future_second):
            if fut_valid_flag:
                cur_time = (i+1)*2
                traj_L2 = self.planning_metric.compute_L2(
                    pred_ego_fut_trajs[0, :cur_time].detach().to(gt_ego_fut_trajs.device),
                    gt_ego_fut_trajs[0, :cur_time]
                )
                traj_L2_single = self.planning_metric.compute_L2(
                    pred_ego_fut_trajs[0, cur_time-1:cur_time].detach().to(gt_ego_fut_trajs.device),
                    gt_ego_fut_trajs[0, cur_time-1:cur_time]
                )
                obj_coll, obj_box_coll = self.planning_metric.evaluate_coll(
                    pred_ego_fut_trajs[:, :cur_time].detach(),
                    gt_ego_fut_trajs[:, :cur_time],
                    occupancy)
                obj_coll_single, obj_box_coll_single = self.planning_metric.evaluate_coll(
                    pred_ego_fut_trajs[:, cur_time-1:cur_time].detach(),
                    gt_ego_fut_trajs[:, cur_time-1:cur_time],
                    occupancy[:, cur_time-1:cur_time])
                metric_dict['plan_L2_{}s'.format(i+1)] = traj_L2
                metric_dict['plan_L2_{}s_single'.format(i+1)] = traj_L2_single
                metric_dict['plan_obj_col_{}s'.format(i+1)] = obj_coll.mean().item()
                metric_dict['plan_obj_box_col_{}s'.format(i+1)] = obj_box_coll.mean().item()
                metric_dict['plan_obj_col_{}s_single'.format(i+1)] = obj_coll_single.item()
                metric_dict['plan_obj_box_col_{}s_single'.format(i+1)] = obj_box_coll_single.item()
                
                
            else:
                metric_dict['plan_L2_{}s'.format(i+1)] = 0.0
                metric_dict['plan_L2_{}s_single'.format(i+1)] = 0.0
                metric_dict['plan_obj_col_{}s'.format(i+1)] = 0.0
                metric_dict['plan_obj_box_col_{}s'.format(i+1)] = 0.0
            
        return metric_dict
    